from re import A
import time
from scipy.spatial import distance
import numpy as np
from heapq import *

class HARP(object):
    """Multi-tree sampling-based planner seeded from critical regions.

    The planner keeps a forest of trees. ``self.roadmap`` is a list of graphs,
    each a list of configurations; ``self.roadmapedges`` is the parallel list
    of adjacency dicts mapping a node index to a list of
    ``(neighbor_index, neighbor_config)`` tuples.

    Construction (run from ``__init__``) does the following:

    1. Seeds ``n`` single-node trees from the critical-region sampler.
    2. Seeds ``m`` more from the abstract-state sampler.
    3. Adds one tree each for start and goal (when both are non-empty).
    4. Grows the trees round-robin, merging them RRT-Connect style until
       start and goal share a tree, all trees have merged, or the time
       budget runs out.

    ``get_mp`` then extracts a path with Dijkstra's algorithm.
    """

    def __init__(self, start, goal, critical_regions, sampling_distribution,
                 collision_fn, discretizer, abstraction, n=50, m=50,
                 max_time=20, simulator=None, hl_regions=None):
        """Set up the planner and immediately build the roadmap.

        Args:
            start, goal: Raw configurations, converted with
                ``discretizer.convert_sample``.
            critical_regions: Object whose ``sample()`` returns a configuration.
            sampling_distribution: Object whose ``sample()`` returns a
                configuration (abstract-state sampling).
            collision_fn: Callable; a truthy result means ``q`` is in collision.
            discretizer: Provides ``convert_sample`` and ``robot`` with DOF limits.
            abstraction: Provides ``get_abstract_state(q).id`` and
                ``check_env_collision(q)``.
            n: Number of critical-region seed trees.
            m: Number of abstract-state seed trees.
            max_time: Build budget in seconds, measured from construction.
            simulator: Optional; only used for plotting via ``simulator.env.plot``.
            hl_regions: Container of allowed abstract-state ids, or None for no
                restriction.
        """
        self.start = discretizer.convert_sample(start)
        self.goal = discretizer.convert_sample(goal)
        # get_mp() overwrites self.start / self.goal with (node_index, config)
        # tuples; keep the plain configurations so later calls stay valid.
        self._start_config = self.start
        self._goal_config = self.goal

        self.init_hl_state = abstraction.get_abstract_state(self.start).id
        self.goal_hl_state = abstraction.get_abstract_state(self.goal).id
        self.hl_regions = hl_regions
        self.collision_fn = collision_fn
        self.critical_regions = critical_regions
        self.sampling_distribution = sampling_distribution
        self.abstraction = abstraction
        self.n = n
        self.m = m
        self.roadmap = None
        self.roadmapedges = None
        self.selected_graph = 0
        self.goal_radius = 0.20
        self.max_time = max_time
        # The build budget is measured from here, not from build_roadmap().
        self.starttime = time.time()
        self.currentgraph = 0
        self.currentstate = "build graphs"
        self.simulator = simulator
        self.k = 0
        self.discretizer = discretizer
        self.llimits = self.discretizer.robot.get_dof_lower_limits()
        self.ulimits = self.discretizer.robot.get_dof_upper_limits()
        self._plot(self.start, [0, 0, 0])
        self._plot(self.goal, [0, 0, 0])
        self.build_roadmap()

    def _plot(self, q, color):
        """Plot q in the simulator environment; no-op when no simulator was given."""
        if self.simulator is not None:
            self.simulator.env.plot(q, color)

    def _in_hl_regions(self, state_id):
        """Return True if state_id is an allowed abstract state (always True when hl_regions is None)."""
        return self.hl_regions is None or state_id in self.hl_regions

    def normalize(self, q):
        """Map each coordinate of q linearly from [lower, upper] DOF limit onto [0, 1]."""
        # (q - u) / (u - l) + 1 == (q - l) / (u - l)
        return [(1.0 / float(self.ulimits[i] - self.llimits[i])) * (q[i] - self.ulimits[i]) + 1
                for i in range(len(q))]

    def normalized_dist(self, a, b, flag=True):
        """Return the Euclidean distance between a and b after per-DOF normalization.

        When ``flag`` is False, ``a`` is taken as already normalized.
        """
        n_a = self.normalize(a) if flag else a
        n_b = self.normalize(b)
        return np.linalg.norm(np.asarray(n_a) - np.asarray(n_b))

    def sample_critical_regions(self):
        """Draw a configuration from the critical-region sampler."""
        return self.critical_regions.sample()

    def sample_abstract_states(self):
        """Draw a configuration from the abstract-state sampling distribution."""
        # TODO: implement high-level state sampling to reduce the state space.
        return self.sampling_distribution.sample()

    def collides(self, q):
        """Return the collision checker's verdict for q (truthy means in collision)."""
        return self.collision_fn(q)

    def get_new_sample(self):
        """Return a growth target.

        The first 100 calls sample critical regions; later calls sample the
        abstract-state distribution.
        """
        if self.k >= 100:
            return self.sample_abstract_states()
        self.k += 1
        return self.sample_critical_regions()

    def _sample_free(self, sampler, plot=False):
        """Call sampler until it yields a collision-free configuration and return it.

        When ``plot`` is set, accepted samples are plotted with [0, 0, 1] and
        rejected ones with [1, 0, 0].
        """
        while True:
            q = sampler()
            free = not self.collides(q)
            if plot:
                self._plot(q, [0, 0, 1] if free else [1, 0, 0])
            if free:
                return q

    def _new_tree(self, q):
        """Append a single-node tree rooted at q to the forest."""
        self.roadmap.append([q])
        self.roadmapedges.append({0: []})

    def build_roadmap(self):
        """Seed the forest, then grow and merge trees until connected or out of time.

        On return ``self.currentstate`` is 'connected graphs'. That happens
        either because connectN() reported success or because ``max_time``
        seconds (measured from construction) elapsed.
        """
        self.roadmap = []
        self.roadmapedges = []

        for _ in range(self.n):
            self._new_tree(self._sample_free(self.sample_critical_regions, plot=True))
        for _ in range(self.m):
            self._new_tree(self._sample_free(self.sample_abstract_states))

        if len(self._start_config) > 0 and len(self._goal_config) > 0:
            self._new_tree(self._start_config)
            self._new_tree(self._goal_config)

        self.buildtime = self.max_time
        while (time.time() - self.starttime) <= self.buildtime:
            rand = self.get_new_sample()
            cg = self.currentgraph
            self.roadmap[cg], self.roadmapedges[cg], status, new = self.extend(
                self.roadmap[cg], self.roadmapedges[cg], rand)
            if status != 'trapped' and self.connectN(new):
                self.currentstate = 'connected graphs'
                return
            if self.currentstate == 'build graphs':
                self.swapN()

        print("Time up", str(time.time()-self.starttime))
        self.currentstate = 'connected graphs'

    def swapN(self):
        """Advance currentgraph round-robin, wrapping to 0 after the last graph."""
        if self.currentgraph >= len(self.roadmap) - 1:
            self.currentgraph = 0
        else:
            self.currentgraph += 1

    def p2p_regions(self, path, start_config, goal_config):
        """Check that every waypoint of path lies in an allowed high-level region.

        Args:
            path: Sequence of configurations, start first.
            start_config: Start configuration (unused; kept for API compatibility).
            goal_config: Goal configuration; its abstract state defines the
                switch point.

        Returns:
            ``(False, None)`` if any waypoint's abstract state is outside
            hl_regions. Otherwise ``(True, idx)``, where idx is the first
            waypoint whose abstract state equals the goal's, or None if no
            waypoint matches.
        """
        goal_id = self.abstraction.get_abstract_state(goal_config).id
        switch_point = None
        for idx, waypoint in enumerate(path):
            waypoint_id = self.abstraction.get_abstract_state(waypoint).id
            if not self._in_hl_regions(waypoint_id):
                return False, None
            if switch_point is None and waypoint_id == goal_id:
                switch_point = idx
        return True, switch_point

    def connect(self, V, E, q, i_q=None):
        """Grow tree (V, E) toward q until it reaches q or is trapped; merge on success.

        On 'reached', the nodes of V are appended to the current graph, with
        every index in E shifted by the graph's previous size. The node of V
        that reached q is then linked to q's node.

        Args:
            V, E: Nodes and adjacency dict of the tree to grow (extended in place).
            q: Target configuration, a node of the current graph.
            i_q: Index of q in the current graph. Defaults to the current graph's
                last node, which is where extend() places a freshly added node.

        Returns:
            The final extend() status: 'reached' or 'trapped'.
        """
        status = 'advanced'
        new = None
        while status == 'advanced':
            V, E, status, new = self.extend(V, E, q)

        if status == 'reached':
            graph = self.roadmap[self.currentgraph]
            edges = self.roadmapedges[self.currentgraph]
            if i_q is None:
                i_q = len(graph) - 1
            # V's nodes go after the existing ones, so every index in E moves by the old size.
            offset = len(graph)
            self.roadmap[self.currentgraph] = graph + V
            for idx, adj in E.items():
                edges[offset + idx] = [(j + offset, cfg) for j, cfg in adj]
            i_new = len(self.roadmap[self.currentgraph]) - 1
            edges[i_q].append((i_new, new))
            edges[i_new].append((i_q, q))
        return status

    def connectN(self, q):
        """Try to connect every other tree to q and merge those that reach it.

        q must be the last node of the current graph (as left by extend()).
        Stops early when the build budget is exhausted.

        Returns:
            True if start and goal now share a graph (selected_graph is set),
            or if only one graph remains.
        """
        # Merges append to the end of the current graph, so q's index is stable.
        i_q = len(self.roadmap[self.currentgraph]) - 1
        merged = set()
        for i in range(len(self.roadmap)):
            if i != self.currentgraph:
                if self.connect(self.roadmap[i], self.roadmapedges[i], q, i_q) == 'reached':
                    merged.add(i)
            if (time.time()-self.starttime) >= self.buildtime:
                break

        if merged:
            self.roadmap = [g for i, g in enumerate(self.roadmap) if i not in merged]
            self.roadmapedges = [e for i, e in enumerate(self.roadmapedges) if i not in merged]
            # Keep currentgraph pointing at the same (merged) graph after the deletions.
            self.currentgraph -= sum(1 for i in merged if i < self.currentgraph)

        if self.check_if_connected():
            return True
        return len(self.roadmap) == 1

    def check_if_connected(self):
        """Return True and set selected_graph if some graph contains both start and goal."""
        for i, graph in enumerate(self.roadmap):
            if self._start_config in graph and self._goal_config in graph:
                self.selected_graph = i
                return True
        return False

    def extend(self, V, E, q):
        """Take one step from the node of V nearest to q toward q.

        Returns:
            A tuple ``(V, E, status, new)``. status is 'trapped' (with new None)
            in two cases:

            * V is empty or dimensionally incompatible with q;
            * the stepped configuration collides or its abstract state is
              outside hl_regions.

            Otherwise new is appended to V and linked both ways with its
            nearest node in E. status is then 'reached' if new is within
            goal_radius of q, else 'advanced'.
        """
        try:
            i_near = int(distance.cdist([q], V).argmin())
        except ValueError:
            # cdist/argmin raise for an empty tree or mismatched dimensions.
            return V, E, 'trapped', None
        near = V[i_near]
        new = self.compound_step(near, q)
        if (self.abstraction.check_env_collision(new)
                or self.collides(new)
                or not self._in_hl_regions(self.abstraction.get_abstract_state(new).id)):
            return V, E, 'trapped', None

        V.append(new)
        i_new = len(V) - 1
        E[i_new] = [(i_near, near)]
        E[i_near].append((i_new, new))
        status = 'reached' if self.goal_zone_collision(new, q) else 'advanced'
        return V, E, status, new

    def compound_step(self, p1, p2, step=0.15):
        """Move p1 toward p2 one coordinate at a time.

        Each coordinate moves by at most ``step``; it snaps to p2's value when
        already within ``step``. The overall move is therefore not a Euclidean
        step of length ``step``.
        """
        stepped = []
        for i, coord in enumerate(p1):
            stepped.extend(self.step_from_to([coord], [p2[i]], step))
        return stepped

    def step_from_to(self, p1, p2, distance):
        """Return p2 if within ``distance`` of p1, else the point ``distance`` along p1->p2.

        Adapted from
        https://github.com/motion-planning/rrt-algorithms/blob/master/src/rrt/rrt_base.py
        """
        if self.dist(p1, p2) <= distance:
            return p2
        a = np.array(p1)
        b = np.array(p2)
        ab = b - a
        # Plain float math: the np.float alias was removed in NumPy 1.24.
        unit_vector = ab / np.linalg.norm(ab)
        return list(a + unit_vector * distance)

    def goal_zone_collision(self, p1, p2):
        """Return True if p1 lies within goal_radius (Euclidean) of p2."""
        return bool(self.dist(p1, p2) <= self.goal_radius)

    def dist(self, a, b):
        """Return the Euclidean distance between two configurations."""
        return np.linalg.norm(np.array(a) - np.array(b))

    def search(self):
        '''Run Dijkstra's shortest path over the selected graph.

        Reads ``self.start`` and ``self.goal`` as ``(node_index, config)``
        tuples (as set by get_mp). Edge weights are Euclidean distances
        between the configurations stored on the edge tuples.

        Returns:
            Configurations from goal back to start (reverse order), or [] when
            the goal node has no predecessor.
        '''
        graph_edges = self.roadmapedges[self.selected_graph]
        best = {i: float("inf") for i in range(len(self.roadmap[self.selected_graph]))}
        prev = {i: None for i in best}

        best[self.start[0]] = 0
        frontier = [(0, self.start)]
        while frontier:
            currdist, near = heappop(frontier)
            if currdist > best[near[0]]:
                continue  # stale entry: node already settled with a shorter distance
            for n in graph_edges[near[0]]:
                alt = currdist + self.dist(near[1], n[1])
                if alt < best[n[0]]:
                    best[n[0]] = alt
                    prev[n[0]] = near
                    heappush(frontier, (alt, n))

        # Backtrack from the goal through predecessor tuples.
        solutiontrace = []
        temp = self.goal
        if prev[temp[0]] is not None:
            while temp is not None:
                solutiontrace.append(temp[1])
                temp = prev[temp[0]]
        return solutiontrace

    def simplify(self, path):
        """Drop consecutive near-duplicate waypoints (np.allclose, rtol=1e-3), then smooth()."""
        if not path:
            return []
        simple = [path[0]]
        for pt in path[1:]:
            if not np.allclose(np.array(pt), np.array(simple[-1]), rtol=1e-3):
                simple.append(pt)
        return self.smooth(simple)

    def smooth(self, path, threshold=0.2):
        """Greedily thin a path by distance.

        From waypoint i, the farthest later waypoint within ``threshold``
        (Euclidean) is emitted and becomes the new i. If no later waypoint is
        within reach, waypoint i itself is emitted and i advances by one.

        NOTE(review): there is no collision check, and the reachable waypoint
        need not be contiguous with i. path[0] and path[-1] may also be
        omitted; get_mp re-adds the start and goal configurations.
        """
        smoothened = []
        i = 0
        while i < len(path) - 1:
            reach = i
            for j in range(i + 1, len(path)):
                if self.dist(path[i], path[j]) <= threshold:
                    reach = j
            smoothened.append(path[reach])
            i = i + 1 if reach == i else reach
        return smoothened

    def get_mp(self):
        """Extract a motion plan from the connected roadmap.

        Sets ``self.start`` / ``self.goal`` to ``(nearest_node_index, config)``
        tuples and runs search(). Reads the original configurations from
        ``_start_config`` / ``_goal_config``, so repeated calls are safe.

        Returns:
            ``(True, path, switch_point)`` with path running start -> goal, or
            ``(False, [], switch_point)`` when the roadmap is not connected, no
            path exists, or the path leaves hl_regions.
        """
        s = self._start_config
        g = self._goal_config
        switchpt = 0
        if self.currentstate != 'connected graphs':
            return False, [], switchpt

        nodes = self.roadmap[self.selected_graph]
        i_s = int(distance.cdist([s], nodes).argmin())
        i_g = int(distance.cdist([g], nodes).argmin())
        self.start = (i_s, s)
        self.goal = (i_g, g)

        path = self.search()
        if not path:
            return False, [], switchpt
        path = path[::-1]
        direct, switchpt = self.p2p_regions(path, s, g)
        if not direct:
            return False, [], switchpt
        path = self.simplify(path)
        return True, [s] + path + [g], switchpt
